import os
from typing import List, Dict, Tuple, Optional, Union, Set

# Generated outputs. Paths are resolved against the current working directory,
# so the script presumably has to be run from the directory holding the headers.
# Headers: C++ ADT definitions, forward declarations / tags, and the C API.
cpp_def, fwd_file, c_def = (
    open(header, 'w') for header in ('ir_v2_defs.h', 'ir_v2_fwd.h', 'ir_v2_api.h')
)
# Implementation sources: C API shims first, then out-of-line C++ definitions.
c_api_impl, cpp_api_impl = (
    open(f'../../../src/ir_v2/{stem}.cpp', 'w') for stem in ('ir_v2_api', 'ir_v2_defs')
)
# (function name, C function-pointer declaration) pairs collected while the
# C API is generated.
func_table = []


# C++ field type -> C ABI type returned by the generated `<Variant>_<field>` getters.
# Scalars and raw pointers pass through unchanged. Strings and vectors are exposed
# as `Slice` views over the field's storage. Shared pointers are exposed as raw
# pointers, and `Func` as a `const CFunc*`.
MAP_FFI_TYPE_GET = {
    **{t: t for t in ('bool', 'uint8_t', 'uint16_t', 'uint32_t', 'uint64_t')},
    'const Type*': 'const Type *',
    'luisa::string': 'Slice<const char>',
    **{
        f'luisa::vector<{t}>': f'Slice<{t}>'
        for t in (
            'uint8_t', 'Node*', 'PhiIncoming', 'SwitchCase', 'Binding',
            'Callable*', 'CallableModule*', 'Module*', 'KernelModule*',
        )
    },
    **{f'luisa::shared_ptr<{t}>': f'{t}*' for t in ('CallableModule', 'CpuExternFn')},
    'Func': 'const CFunc*',
    **{t: t for t in ('Node*', 'BasicBlock*', 'const BasicBlock*')},
}
# C++ field type -> C ABI parameter type accepted by the generated setters and
# `<Variant>_new` constructors. It matches MAP_FFI_TYPE_GET except for `Func`,
# which is passed by value as the `CFunc` mirror struct instead of by const
# pointer. The table is derived from the getter table so the two cannot drift
# apart. 'Func' keeps its original key position because the key already exists.
MAP_FFI_TYPE_SET = {
    **MAP_FFI_TYPE_GET,
    'Func': 'CFunc',
}


def to_screaming_snake_case(name: str):
    """Convert a PascalCase identifier to SCREAMING_SNAKE_CASE.

    An underscore is inserted before every uppercase letter except the first
    character. Digits are not word boundaries, so ``'Texture2d'`` becomes
    ``'TEXTURE2D'``.
    """
    pieces = []
    for index, char in enumerate(name):
        if index > 0 and char.isupper():
            pieces.append('_')
        pieces.append(char.upper())
    return ''.join(pieces)


# ---------------------------------------------------------------------------
# File preambles. Each generated file gets its header guard / includes here and
# `namespace luisa::compute::ir_v2 {` is opened in it; the rest of the script
# appends into these open namespaces.
# ---------------------------------------------------------------------------
print('#pragma once', file=cpp_def)
# ir_v2_fwd.h: MSVC warning C4190 (C-linkage function returning a C++ UDT) is
# disabled for this header.
print('''#pragma once
// if msvc
#ifdef _MSC_VER
#pragma warning( disable : 4190)
#endif
''', file=fwd_file)
print('#include <cstdint>', file=fwd_file)
print('#include <luisa/core/dll_export.h>', file=fwd_file)
# STL-dependent includes are hidden when BINDGEN is defined (presumably while
# the header is processed by rust-bindgen -- confirm against the build scripts).
print('#ifndef BINDGEN', file=fwd_file)
print('#include <array>', file=fwd_file)
print('#include <luisa/core/stl/memory.h>', file=fwd_file)
print('#include <luisa/core/stl/string.h>', file=fwd_file)
print('#include <luisa/core/stl/vector.h>', file=fwd_file)
print('#endif', file=fwd_file)
print('#pragma once', file=c_def)
# ir_v2_defs.h: full C++ ADT definitions, layered on top of the forward header.
print('/// This file is generated by gen_ir_def.py', file=cpp_def)
print('#include <type_traits>', file=cpp_def)
print('#include <luisa/ir_v2/ir_v2_fwd.h>', file=cpp_def)
print('#include <luisa/core/logging.h>', file=cpp_def)
print('#include <luisa/ast/type.h>', file=cpp_def)
print('#include <luisa/ast/type_registry.h>', file=cpp_def)
print('namespace luisa::compute::ir_v2 {', file=cpp_def)
# `Type` lives in luisa::compute, so it is forward-declared before ir_v2 opens.
print(
    'namespace luisa::compute { class Type;} namespace luisa::compute::ir_v2 {', file=fwd_file)
# Includes for ir_v2_api.h (C API declarations) and ir_v2_api.cpp (their impl).
print('#include <luisa/ir_v2/ir_v2_fwd.h>', file=c_def)
print('#include <luisa/ir_v2/ir_v2_defs.h>', file=c_api_impl)
print('#include <luisa/ir_v2/ir_v2.h>', file=c_api_impl)
print('#include <luisa/ir_v2/ir_v2_api.h>', file=c_api_impl)
# `Slice<T>`: a non-owning (pointer, length) view used to pass strings and
# vectors across the C API. The converting constructors and the copying
# to_vector()/to_string() helpers are hidden from BINDGEN.
print('''
namespace luisa::compute::ir_v2 {
/** 
* <div rustbindgen nodebug></div>
*/
template<class T>
struct Slice {
    T *data = nullptr;
    size_t len = 0;
    constexpr Slice() noexcept = default;
    constexpr Slice(T *data, size_t len) noexcept : data(data), len(len) {}
#ifndef BINDGEN
    // construct from array
    template<size_t N>
    constexpr Slice(T (&arr)[N]) noexcept : data(arr), len(N) {}
    // construct from std::array
    template<size_t N>
    constexpr Slice(std::array<T, N> &arr) noexcept : data(arr.data()), len(N) {}
    // construct from luisa::vector
    constexpr Slice(luisa::vector<T> &vec) noexcept : data(vec.data()), len(vec.size()) {}
    // construct from luisa::span
    constexpr Slice(luisa::span<T> &span) noexcept : data(span.data()), len(span.size()) {}
    // construct from luisa::string
    constexpr Slice(luisa::string &str) noexcept : data(str.data()), len(str.size()) {
        static_assert(std::is_same_v<T, char> || std::is_same_v<T, const char>);
    }
    luisa::vector<T> to_vector() const noexcept {
        return luisa::vector<T>(data, data + len);
    }
    luisa::string to_string() const noexcept {
        static_assert(std::is_same_v<T, char> || std::is_same_v<T, const char>);
        return luisa::string(data, len);
    }
#endif 
};    
}
''', file=c_def)
print('namespace luisa::compute::ir_v2 {', file=c_def)
print('namespace luisa::compute::ir_v2 {', file=c_api_impl)
# Forward declarations and the bindgen-facing Ref/RefMut typedefs. RustyTypeTag
# lists the type tags with Rust-style names; the trailing comment on each
# entry keeps the SCREAMING_CASE spelling.
print('''
struct Node;
class BasicBlock;
struct CallableModule;
struct Module;
struct KernelModule; 
class Pool;
template<class T>
struct Slice;
struct CInstruction;
struct CFunc;
struct CBinding;
// Don't touch!! These typedef are for bindgen
typedef const Node *NodeRef;
typedef Node *NodeRefMut;
typedef const BasicBlock *BasicBlockRef;
typedef BasicBlock *BasicBlockRefMut;
/**
* <div rustbindgen nocopy></div>
*/
typedef const CallableModule *CallableModuleRef;
/**
* <div rustbindgen nocopy></div>
*/
typedef CallableModule *CallableModuleRefMut;
/**
* <div rustbindgen nocopy></div>
*/
typedef const Module *ModuleRef;
/**
* <div rustbindgen nocopy></div>
*/
typedef Module *ModuleRefMut;
typedef const KernelModule *KernelModuleRef;
/**
* <div rustbindgen nocopy></div>
*/
typedef KernelModule *KernelModuleRefMut;
/**
* <div rustbindgen nocopy></div>
*/
typedef const Pool *PoolRef;
/**
* <div rustbindgen nocopy></div>
*/
typedef Pool *PoolRefMut;
typedef const Type *TypeRef;
enum class RustyTypeTag {
    Bool,//BOOL,
    Int8,//INT8,
    Uint8,//UINT8,
    Int16,//INT16,
    Uint16,//UINT16,
    Int32,//INT32,
    Uint32,//UINT32,
    Int64,//INT64,
    Uint64,//UINT64,
    Float16,//FLOAT16,
    Float32,//FLOAT32,
    Float64,//FLOAT64,

    Vector,//VECTOR,
    Matrix,//MATRIX,

    Array,//,ARRAY,
    Struct,//,STRUCTURE,

    __HIDDEN_BUFFER,
    __HIDDEN_TEXTURE,
    __HIDDEN_BINDLESS_ARRAY,
    __HIDDEN_ACCEL,

    Custom,//CUSTOM
};
      
/**
* <div rustbindgen nocopy></div>
*/
class IrBuilder;
/**
* <div rustbindgen nocopy></div>
*/
typedef const IrBuilder *IrBuilderRef;
/**
* <div rustbindgen nocopy></div>
*/
typedef IrBuilder *IrBuilderRefMut;

    
''', file=fwd_file)
# ir_v2_defs.cpp: out-of-line C++ definitions (e.g. ADT converting constructors).
print('#include <luisa/ir_v2/ir_v2.h>', file=cpp_api_impl)
print('namespace luisa::compute::ir_v2 {', file=cpp_api_impl)

class Item:
    """One variant of a generated ADT (an instruction, a function, a binding, ...).

    ``fields`` is a list of ``(cpp_type, field_name)`` pairs.

    * ``gen`` renders the body of the variant's C++ struct.
    * ``gen_c_api`` writes C-callable getters, setters and a constructor into
      ``c_api_impl`` and records their function-pointer signatures in
      ``func_table``.
    """

    def __init__(self, name, base, fields: List[Tuple[str, str]], comment=None, no_copy=False) -> None:
        # Extra C++ appended verbatim to the generated struct body (see `gen`).
        self.cpp_src = ''
        self.name = name
        self.fields = fields
        # Stored for callers; not read by the generator code in this class.
        self.comment = comment
        # Name of the ADT this variant belongs to, e.g. 'Instruction' or 'Binding'.
        self.base = base
        # Tags are derived from the name as passed in. Subclasses append their
        # struct-name suffix only after this constructor has run.
        self.tag = to_screaming_snake_case(name)
        self.tag_rs = str(name)
        # When set, the variant's Ref/RefMut typedefs are marked `rustbindgen nocopy`.
        self.no_copy = no_copy

    def gen(self):
        """Return the C++ struct body for this variant.

        The body contains the value-initialized members, a defaulted default
        constructor, a member-wise constructor when there are fields, and
        ``cpp_src``.
        """
        out = 'public:\n'
        for ty, name in self.fields:
            out += f'    {ty} {name}{{}};\n'
        out += f'    {self.name}() = default;\n'
        if self.fields:
            params = ', '.join(f'{ty} {name}' for ty, name in self.fields)
            inits = ', '.join(f'{name}(std::move({name}))' for _, name in self.fields)
            out += f'    {self.name}({params}) : {inits} {{}}\n'
        out += self.cpp_src
        return out

    def gen_c_api(self):
        """Emit the C API for this variant into ``c_api_impl``.

        The generated functions are ``<Name>_<field>`` getters,
        ``<Name>_set_<field>`` setters and a ``<Name>_new`` constructor, each
        registered in ``func_table`` in that order. Variants without fields get
        no C API, and ``''`` is returned for them.
        """
        if len(self.fields) == 0:
            return ''
        self._gen_c_getters()
        self._gen_c_setters()
        self._gen_c_ctor()

    def _gen_c_getters(self):
        # One `<Name>_<field>(self)` accessor per field, typed via MAP_FFI_TYPE_GET.
        for ty, field in self.fields:
            fname = f'{self.name}_{field}'
            ret = MAP_FFI_TYPE_GET[ty]
            func_table.append((fname, f'{ret} (*{fname})({self.name} *self)'))
            print(f'static {ret} {fname}({self.name} *self) {{', file=c_api_impl)
            if 'shared_ptr' in ty:
                # Ownership stays with the field; the caller gets the raw pointer.
                print(f'    return self->{field}.get();', file=c_api_impl)
            elif ty == 'Func':
                # Assumes CFunc is layout-compatible with Func (the ctor below
                # relies on the same bit-copy) -- confirm if either type changes.
                print(f'    return reinterpret_cast<const CFunc*>(&self->{field});', file=c_api_impl)
            else:
                print(f'    return self->{field};', file=c_api_impl)
            print('}', file=c_api_impl)

    def _gen_c_setters(self):
        # One `<Name>_set_<field>(self, value)` per field, typed via MAP_FFI_TYPE_SET.
        for ty, field in self.fields:
            fname = f'{self.name}_set_{field}'
            arg = MAP_FFI_TYPE_SET[ty]
            func_table.append((fname, f'void (*{fname})({self.name} *self, {arg} value)'))
            print(f'static void {fname}({self.name} *self, {arg} value) {{', file=c_api_impl)
            if 'shared_ptr' in ty:
                # Recover shared ownership from the raw pointer handed over the C API.
                body = (f'self->{field} = luisa::static_pointer_cast<std::decay_t<decltype(self->{field})>'
                        f'::element_type>(value->shared_from_this());')
            elif ty.startswith('luisa::vector'):
                body = f'self->{field} = value.to_vector();'
            elif ty.startswith('luisa::string'):
                body = f'self->{field} = value.to_string();'
            elif ty == 'Func':
                body = f'self->{field} = std::move(*reinterpret_cast<Func*>(&value));'
            else:
                body = f'self->{field} = value;'
            print(f'    {body}', file=c_api_impl)
            print('}', file=c_api_impl)

    def _gen_c_ctor(self):
        # `<Name>_new(pool, fields...)` returns the C mirror `C<base>` of a fresh
        # `<base>` ADT that owns a newly allocated variant.
        fname = f'{self.name}_new'
        params = ''.join(f', {MAP_FFI_TYPE_SET[ty]} {field}' for ty, field in self.fields)
        func_table.append((fname, f'C{self.base} (*{fname})(Pool *pool{params})'))
        print(f'static C{self.base} {fname}(Pool *pool{params}) {{', file=c_api_impl)
        # Must actually allocate the variant. A default-constructed unique_ptr is
        # null, and it would be dereferenced by the setters below and by the
        # `<base>(tag, data)` constructor's tag check.
        print(f'    auto data = luisa::make_unique<{self.name}>();', file=c_api_impl)
        for _, field in self.fields:
            print(f'    {self.name}_set_{field}(data.get(), {field});', file=c_api_impl)
        print(f'    auto tag = {self.name}::static_tag();', file=c_api_impl)
        print(f'    auto cobj = C{self.base}{{}};', file=c_api_impl)
        print(f'    auto obj = {self.base}(tag, std::move(data));', file=c_api_impl)
        # Bit-copy the ADT into its C mirror, then release ownership so that
        # `obj`'s destructor does not free the payload now referenced by `cobj`.
        print(f'    std::memcpy(&cobj, &obj, sizeof(C{self.base}));', file=c_api_impl)
        print('    (void)obj.steal();', file=c_api_impl)
        print('    return cobj;', file=c_api_impl)
        print('}', file=c_api_impl)


class Instruction(Item):
    """A variant of the `Instruction` ADT; its generated C++ struct is `<name>Inst`.

    `cpp_src`, when given, is appended verbatim to the generated struct body.
    """

    def __init__(self, name, fields=None, cpp_src=None, **kwargs) -> None:
        super().__init__(name, 'Instruction', [] if fields is None else fields, **kwargs)
        # Item.__init__ has already derived tag/tag_rs from the bare name.
        self.name = self.name + 'Inst'
        if cpp_src is not None:
            self.cpp_src = cpp_src


class Func(Item):
    """A variant of the `Func` ADT; its generated C++ struct is `<name>Fn`.

    `side_effects` is recorded for the generated `FuncMetadata` table.
    """

    def __init__(self, name, fields=None, side_effects=False, **kwargs) -> None:
        super().__init__(name, 'Func', [] if fields is None else fields, **kwargs)
        # Item.__init__ has already derived tag/tag_rs from the bare name.
        self.name = self.name + 'Fn'
        self.side_effects = side_effects



class CppType:
    """Base class of the small C++ type model built by `parse_cpp_type`."""

    def __init__(self) -> None:
        pass


class CppAtom(CppType):
    """A type spelling that is not decomposed further, e.g. ``uint32_t`` or ``const Type``."""

    def __init__(self, name) -> None:
        super().__init__()
        self.name = name


class CppList(CppType):
    """``luisa::vector<ty>``."""

    def __init__(self, ty: CppType) -> None:
        super().__init__()
        self.ty = ty


class CppString(CppType):
    """``luisa::string``."""

    def __init__(self) -> None:
        super().__init__()


class CppPointer(CppType):
    """A raw pointer ``ty*``."""

    def __init__(self, ty: CppType) -> None:
        super().__init__()
        self.ty = ty


class CppSharedPtr(CppType):
    """``luisa::shared_ptr<ty>``."""

    def __init__(self, ty: CppType) -> None:
        super().__init__()
        self.ty = ty


def _template_argument(s: str, template: str) -> str:
    """Return the text between ``<`` and the final ``>`` of ``s``, which starts with ``template``."""
    rest = s[len(template):]
    if not (rest.startswith('<') and rest.endswith('>')):
        raise ValueError(f'malformed {template} type: {s!r}')
    return rest[1:-1]


def parse_cpp_type(s: str) -> CppType:
    """Parse a C++ type spelling into a `CppType` tree.

    After surrounding whitespace is stripped, the forms are recognised in this order:

    * a trailing ``*`` becomes ``CppPointer`` of the rest, so the pointer always
      applies to the whole type;
    * ``luisa::vector<T>`` becomes ``CppList``;
    * ``luisa::shared_ptr<T>`` becomes ``CppSharedPtr``;
    * ``luisa::string`` becomes ``CppString``;
    * anything else becomes a ``CppAtom`` holding the stripped spelling. For
      example, ``const Type*`` parses to ``CppPointer(CppAtom('const Type'))``.

    Raises:
        ValueError: if a ``luisa::vector`` / ``luisa::shared_ptr`` prefix is not
            followed by ``<...>``.
    """
    s = s.strip()
    if s.endswith('*'):
        return CppPointer(parse_cpp_type(s[:-1]))
    if s.startswith('luisa::vector'):
        return CppList(parse_cpp_type(_template_argument(s, 'luisa::vector')))
    if s.startswith('luisa::shared_ptr'):
        return CppSharedPtr(parse_cpp_type(_template_argument(s, 'luisa::shared_ptr')))
    if s == 'luisa::string':
        return CppString()
    return CppAtom(s)


def gen_adt(adt: str, cpp_src: str, variants: List[Item]):
    """Generate a tagged-union ("ADT") type, its variants, and its C API.

    Args:
        adt: base name of the union, e.g. 'Func', 'Instruction', 'Binding'.
        cpp_src: extra C++ inserted verbatim into the body of `struct <adt>`.
        variants: the variants, in order. The order defines the numeric values
            of both `<adt>Tag` and `Rusty<adt>Tag`, and the C API converts
            between the two with static_cast.

    Writes to the open output files:
        fwd_file: forward declarations, the two tag enums, `tag_name()`, the
            `<adt>Data` base class and per-variant Ref/RefMut typedefs.
        cpp_def: `struct <adt>` (a `unique_ptr<<adt>Data>` plus a tag) and one
            `<adt>Data` subclass per variant that has fields.
        cpp_api_impl: out-of-line `<adt>(<Variant>)` converting constructors.
        c_def: the `C<adt>` mirror struct.
        c_api_impl: `<adt>_as_<Variant>`, `<adt>_tag`, each variant's
            `gen_c_api()` output and `<adt>_new`; all are appended to
            `func_table`.

    Variants without fields only contribute a tag. No struct, typedefs,
    converting constructor or C accessors are generated for them.
    """
    # gen cpp
    # Forward declarations; `<adt>Ref`/`<adt>RefMut` point at the C mirror type.
    print('struct {};'.format(adt), file=fwd_file)
    print('struct {}Data;'.format(adt), file=fwd_file)
    print(f'typedef const C{adt}* {adt}Ref;', file=fwd_file)
    print(f'typedef C{adt}* {adt}RefMut;', file=fwd_file)
    

    # C++ tag enum (SCREAMING_CASE names) ...
    print(f'    enum class {adt}Tag : unsigned int {{', file=fwd_file)
    for variant in variants:
        print('        {},'.format(
            variant.tag), file=fwd_file)
    print('    };', file=fwd_file)

    # ... and a parallel enum with the original PascalCase names, same order.
    print(f'    enum class Rusty{adt}Tag : unsigned int {{', file=fwd_file)
    for variant in variants:
        print('        {},'.format(
            variant.tag_rs), file=fwd_file)
    print('    };', file=fwd_file)

    # tag_name() maps a tag to the variant's generated struct name (e.g.
    # "UndefFn"); values not matched by the switch return "unknown".
    print(f'    inline const char* tag_name({adt}Tag tag) {{', file=fwd_file)
    print(f'        switch(tag) {{',  file=fwd_file)
    for variant in variants:
        print(f'        case {adt}Tag::{variant.tag}: return "{variant.name}";', file=fwd_file)
    print('}', file=fwd_file)
    print('return "unknown";', file=fwd_file)
    print('}', file=fwd_file)



    # Polymorphic payload base; its virtual members are hidden from BINDGEN.
    print(f'struct LC_IR_API {adt}Data {{ ', file=fwd_file)
    print('#ifndef BINDGEN', file=fwd_file)
    print(f'    virtual {adt}Tag tag() const noexcept = 0;', file=fwd_file)
    print('    virtual ~{}Data() = default;'.format(adt), file=fwd_file)
    print('#endif', file=fwd_file)
    print('};', file=fwd_file)
    # Forward decl + Ref/RefMut typedefs for variants that get a struct.
    for variant in variants:
        if len(variant.fields) > 0:
            print('struct {};'.format(variant.name), file=fwd_file)
            if variant.no_copy:
                print('''/**
* <div rustbindgen nocopy></div>
*/''', file=fwd_file)
            print(f'typedef const {variant.name}* {variant.name}Ref;', file=fwd_file)
            if variant.no_copy:
                print('''/**
* <div rustbindgen nocopy></div>
*/''', file=fwd_file)
            print(f'typedef {variant.name}* {variant.name}RefMut;', file=fwd_file)
    # The ADT itself: an owning pointer to the payload plus the tag. The
    # tag-only constructor leaves `_data` null; the (tag, data) constructor
    # asserts that the payload's own tag matches.
    print('struct LC_IR_API {} {{'.format(adt), file=cpp_def)
    print('    luisa::unique_ptr<{}Data> _data;'.format(adt), file=cpp_def)
    print('     {}Tag _tag;'.format(adt), file=cpp_def)
    print('public:', file=cpp_def)
    print(f'   explicit {adt}({adt}Tag tag) : _data(luisa::unique_ptr<{adt}Data>()), _tag(tag) {{}}', file=cpp_def)
    print('    explicit {}({}Tag tag, luisa::unique_ptr<{}Data> data) : _data(std::move(data)), _tag(tag) {{'.format(
        adt, adt, adt), file=cpp_def)
    print(f'        LUISA_ASSERT(tag == _data->tag(), "Mismatched tag!!!");', file=cpp_def)
    print('    }', file=cpp_def)
    print('    typedef {}Tag Tag;'.format(adt), file=cpp_def)
    # Converting constructor per variant: declared here, defined in cpp_api_impl.
    for variant in variants:
        if len(variant.fields) > 0:
            print(f'    explicit {adt}({variant.name} v);', file=cpp_def)
            print(f'    {adt}::{adt}({variant.name} v):_data(luisa::make_unique<{variant.name}>(std::move(v))), _tag({variant.name}::static_tag()) {{}}', file=cpp_api_impl)
    # Tag queries and checked downcasts (as<T>() returns nullptr on tag mismatch).
    print('    [[nodiscard]] Tag tag() const noexcept {', file=cpp_def)
    print('        return _tag;', file=cpp_def)
    print('    }', file=cpp_def)
    print('    [[nodiscard]] bool isa(Tag tag)const noexcept {{'.format(
        adt), file=cpp_def)
    print('        return this->tag() == tag;', file=cpp_def)
    print('    }', file=cpp_def)
    print('     template<class T> requires std::is_base_of_v<{}Data, T>  [[nodiscard]] bool isa()const noexcept {{'.format(
        adt), file=cpp_def)
    print('        return this->isa(T::static_tag());', file=cpp_def)
    print('    }', file=cpp_def)
    print('    template<class T> requires std::is_base_of_v<{}Data, T> [[nodiscard]]  T* as() {{'.format(
        adt), file=cpp_def)
    print('        return isa(T::static_tag()) ? static_cast<T*>(_data.get()) : nullptr;', file=cpp_def)
    print('    }', file=cpp_def)
    print('    template<class T> requires std::is_base_of_v<{}Data, T> [[nodiscard]] const T* as() const {{'.format(
        adt), file=cpp_def)
    print('        return isa(T::static_tag()) ? static_cast<const T*>(_data.get()) : nullptr;', file=cpp_def)
    print('    }', file=cpp_def)
    print('    ', cpp_src, file=cpp_def)
    # steal() releases the payload without destroying it; the C API calls it
    # after bit-copying the ADT into `C<adt>`.
    print(f'    [[nodiscard]] {adt}Data * steal() noexcept {{ ', file=cpp_def)
    print('        return _data.release();', file=cpp_def)
    print('    }', file=cpp_def)
    print('};', file=cpp_def)
    # The C mirror is filled by memcpy from the ADT, so both sizes are pinned.
    print('static_assert(sizeof({}) == 16);'.format(adt), file=cpp_def)
    print('static_assert(sizeof(luisa::unique_ptr<{}Data>) == 8);'.format(adt), file=cpp_def)

    # One `<adt>Data` subclass per variant that has fields.
    for variant in variants:
        if len(variant.fields) == 0:
            continue
        print('struct LC_IR_API {} : public {}Data {{'.format(
            variant.name, adt), file=cpp_def)
        print('public:', file=cpp_def)
        print('    typedef {}Tag Tag;'.format(adt), file=cpp_def)
        print('    [[nodiscard]] Tag tag() const noexcept override {', file=cpp_def)
        print('        return static_tag();', file=cpp_def)
        print('    }', file=cpp_def)
        print('    static constexpr Tag static_tag() noexcept {', file=cpp_def)
        print('        return Tag::{};'.format(
            variant.tag), file=cpp_def)
        print('    }', file=cpp_def)
        print('    ', variant.gen(), file=cpp_def)
        print('};', file=cpp_def)

    # gen c api
    # `C<adt>` is the bindgen-visible mirror: { payload pointer, tag }.
    print(f'/**\n* <div rustbindgen nocopy></div>\n*/\nstruct C{adt} {{ void *data; {adt}Tag tag; }};', file=c_def)
    print(f'static_assert(sizeof(C{adt}) == 16);', file=c_def)
    for variant in variants:
        # C accessors are emitted as `static` functions in c_api_impl and are
        # registered in func_table instead of being exported individually.
        if len(variant.fields) > 0:
            fname = f'{adt}_as_{variant.name}'
            fsig = f'{variant.name} *(*{fname})(C{adt} *self)'
            func_table.append((fname, fsig))
            print(f'static {variant.name} *{fname}(C{adt} *self) {{',
                file=c_api_impl)
            print('    return reinterpret_cast<{1}*>(self)->as<{0}>();'.format(variant.name, adt),
                file=c_api_impl)
            print('}', file=c_api_impl)
    

    # `<adt>_tag` reports the tag as the Rusty enum (same numeric value).
    fname = f'{adt}_tag'
    fsig = f'Rusty{adt}Tag (*{fname})(const C{adt} *self)'
    func_table.append((fname, fsig))
    print(f'static Rusty{adt}Tag {fname}(const C{adt} *self) {{', file=c_api_impl)
    print(f'    return static_cast<Rusty{adt}Tag>(reinterpret_cast<const {adt}*>(self)->tag());', file=c_api_impl)
    print('}', file=c_api_impl)
    for variant in variants:
        variant.gen_c_api()

    # `<adt>_new` builds a payload-less ADT (null `_data`) with the given tag.
    fname = f'{adt}_new'
    fsig = f'C{adt} (*{fname})(Pool *pool, Rusty{adt}Tag tag)'
    func_table.append((fname, fsig))
    print(f'static C{adt} {fname}(Pool *pool, Rusty{adt}Tag tag) {{', file=c_api_impl)
    print(f'    auto obj = {adt}(static_cast<{adt}Tag>(tag));', file=c_api_impl)
    print(f'    auto cobj = C{adt}{{}};', file=c_api_impl)
    print(f'    std::memcpy(&cobj, &obj, sizeof(C{adt}));', file=c_api_impl)
    print(f'    (void)obj.steal();', file=c_api_impl)
    print(f'    return cobj;', file=c_api_impl)
    print('}', file=c_api_impl)

# Variants of the `Instruction` ADT. List order is significant: gen_adt emits
# InstructionTag / RustyInstructionTag enumerators in this order. Each name
# gets an `Inst` suffix for its C++ struct. Variants without fields produce
# only a tag (no struct, no C accessors).
instructions = [
    # Variants without fields.
    Instruction('Buffer'),
    Instruction('Texture2d'),
    Instruction('Texture3d'),
    Instruction('BindlessArray'),
    Instruction('Accel'),
    Instruction('Shared'),
    Instruction('Uniform'),
    Instruction('Argument', [
        ('bool', 'by_value'),
    ]),
    # `cpp_src` adds typed readers that reinterpret the raw `value` bytes, each
    # asserting that `ty` matches.
    # NOTE(review): the readers do not check value.size(), and the
    # reinterpret_cast reads may run into alignment/strict-aliasing issues
    # (std::memcpy would avoid them). There is also no as_float64/as_int8/as_uint8
    # reader -- confirm whether these are needed.
    Instruction('Constant', [
        ('const Type*', 'ty'),
        ('luisa::vector<uint8_t>', 'value')
    ],cpp_src='''
[[nodiscard]] uint16_t as_uint16() const noexcept {
    LUISA_ASSERT(ty->is_uint16(), "Type mismatch!");
    return *reinterpret_cast<const uint16_t*>(value.data());
}
[[nodiscard]] uint32_t as_uint32() const noexcept {
    LUISA_ASSERT(ty->is_uint32(), "Type mismatch!");
    return *reinterpret_cast<const uint32_t*>(value.data());
}
[[nodiscard]] uint64_t as_uint64() const noexcept {
    LUISA_ASSERT(ty->is_uint64(), "Type mismatch!");
    return *reinterpret_cast<const uint64_t*>(value.data());
}
[[nodiscard]] int16_t as_int16() const noexcept {
    LUISA_ASSERT(ty->is_int16(), "Type mismatch!");
    return *reinterpret_cast<const int16_t*>(value.data());
}
[[nodiscard]] int32_t as_int32() const noexcept {
    LUISA_ASSERT(ty->is_int32(), "Type mismatch!");
    return *reinterpret_cast<const int32_t*>(value.data());
}
[[nodiscard]] int64_t as_int64() const noexcept {
    LUISA_ASSERT(ty->is_int64(), "Type mismatch!");
    return *reinterpret_cast<const int64_t*>(value.data());
}
[[nodiscard]] half as_float16() const noexcept {
    LUISA_ASSERT(ty->is_float16(), "Type mismatch!");
    return *reinterpret_cast<const half*>(value.data());
}
[[nodiscard]] float as_float32() const noexcept {
    LUISA_ASSERT(ty->is_float32(), "Type mismatch!");
    return *reinterpret_cast<const float*>(value.data());
}
[[nodiscard]] bool as_bool() const noexcept {
    LUISA_ASSERT(ty->is_bool(), "Type mismatch!");
    return *reinterpret_cast<const bool*>(value.data());
}
'''),
    # The `Func` field is embedded by value, so the Func ADT must be generated first.
    Instruction('Call', [
        ('Func', 'func'),
        ('luisa::vector<Node*>', 'args'),
    ]),
    Instruction('Phi', [('luisa::vector<PhiIncoming>', 'incomings')]),
    Instruction("BasicBlockSentinel", []),
    # Control-flow variants reference their sub-blocks via `const BasicBlock*`.
    Instruction('If', [
        ('Node*', 'cond'),
        ('const BasicBlock*', 'true_branch'),
        ('const BasicBlock*', 'false_branch')
    ]),
    Instruction('GenericLoop', [
        ('const BasicBlock*', 'prepare'),
        ('Node*', 'cond'),
        ('const BasicBlock*', 'body'),
        ('const BasicBlock*', 'update')
    ]),
    Instruction('Switch', [
        ('Node*', 'value'),
        ('luisa::vector<SwitchCase>', 'cases'),
        ('const BasicBlock*', 'default_')
    ]),
    Instruction('Local', [
        ('Node*', 'init')
    ]),
    Instruction('Break', []),
    Instruction('Continue', []),
    Instruction('Return', [
        ('Node*', 'value')
    ]),
    Instruction('Print', [
        ('luisa::string', 'fmt'),
        ('luisa::vector<Node*>', 'args')
    ]),
    Instruction('Comment', [
        ('luisa::string', 'comment')
    ]),
    Instruction('Update', [
        ('Node*', 'var'),
        ('Node*', 'value')
    ]),
    Instruction('RayQuery', [
        ('Node*', 'query'),
        ('const BasicBlock*', 'on_triangle_hit'),
        ('const BasicBlock*', 'on_procedural_hit'),
    ]),
    Instruction('RevAutodiff', [
        ('const BasicBlock*', 'body'),
    ]),
    Instruction('FwdAutodiff', [
        ('const BasicBlock*', 'body'),
    ]),
]

# Variants of the `Func` ADT, used as the `func` field of CallInst.
# List order is significant:
#   - it fixes the FuncTag / RustyFuncTag values;
#   - gen_func_metadata emits one `_func_metadata` row per entry in this
#     order, and Func::metadata() indexes that table by tag.
# Each name gets an `Fn` suffix, and only variants with fields get a C++
# struct. `side_effects` becomes FuncMetadata::has_side_effects.
# NOTE(review): Assume/Unreachable/Assert and ShaderExecutionReorder are not
# flagged side_effects=True. Confirm that passes consulting has_side_effects()
# cannot drop them.
funcs = [
    Func('Undef'),
    Func('Zero', []),
    Func('One', []),

    Func('Assume', [
        ('luisa::string', 'msg')
    ]),
    Func('Unreachable', [
         ('luisa::string', 'msg')
    ]),
    Func('Assert', [('luisa::string', 'msg')]),
    Func('ThreadId', []),
    Func('BlockId', []),
    Func('WarpSize', []),
    Func('WarpLaneId', []),
    Func('DispatchId', []),
    Func('DispatchSize', []),

    Func('PropagateGradient', [], side_effects=True),
    Func('OutputGradient', []),

    Func('RequiresGradient', [], side_effects=True),
    Func('Backward', [], comment='//Backward(out, out_grad)', side_effects=True),
    Func('Gradient', []),
    Func('AccGrad', [], side_effects=True),
    Func('Detach', []),

    Func('RayTracingInstanceTransform'),
    Func('RayTracingInstanceVisibilityMask'),
    Func('RayTracingInstanceUserId'),
    Func('RayTracingSetInstanceTransform', side_effects=True),
    Func('RayTracingSetInstanceOpacity', side_effects=True),
    Func('RayTracingSetInstanceVisibility', side_effects=True),
    Func('RayTracingSetInstanceUserId', side_effects=True),

    Func('RayTracingTraceClosest', []),
    Func('RayTracingTraceAny', []),
    Func('RayTracingQueryAll', []),
    Func('RayTracingQueryAny', []),
    Func('RayQueryWorldSpaceRay', []),
    Func('RayQueryProceduralCandidateHit', []),
    Func('RayQueryTriangleCandidateHit', []),
    Func('RayQueryCommittedHit', []),
    Func('RayQueryCommitTriangle', [], side_effects=True),
    Func('RayQueryCommitProcedural', [], side_effects=True),
    Func('RayQueryTerminate', [], side_effects=True),

    Func('Load', []),

    Func('Cast', []),
    Func('BitCast', []),

    Func('Add'),
    Func('Sub'),
    Func('Mul'),
    Func('Div'),
    Func('Rem'),
    Func('BitAnd'),
    Func('BitOr'),
    Func('BitXor'),
    Func('Shl'),
    Func('Shr'),
    Func('RotRight'),
    Func('RotLeft'),
    Func('Eq'),
    Func('Ne'),
    Func('Lt'),
    Func('Le'),
    Func('Gt'),
    Func('Ge'),
    Func('MatCompMul'),

    Func('Neg'),
    Func('Not'),
    Func('BitNot'),

    Func('All', []),
    Func('Any', []),

    Func('Select'),
    Func('Clamp'),
    Func('Lerp'),
    Func('Step'),
    Func('Saturate'),
    Func('SmoothStep'),

    Func('Abs'),
    Func('Min'),
    Func('Max'),

    Func('ReduceSum'),
    Func('ReduceProd'),
    Func('ReduceMin'),
    Func('ReduceMax'),
    Func('Clz'),
    Func('Ctz'),
    Func('PopCount'),
    Func('Reverse'),
    Func('IsInf'),
    Func('IsNan'),
    Func('Acos'),
    Func('Acosh'),
    Func('Asin'),
    Func('Asinh'),
    Func('Atan'),
    Func('Atan2'),
    Func('Atanh'),
    Func('Cos'),
    Func('Cosh'),
    Func('Sin'),
    Func('Sinh'),
    Func('Tan'),
    Func('Tanh'),
    Func('Exp'),
    Func('Exp2'),
    Func('Exp10'),
    Func('Log'),
    Func('Log2'),
    Func('Log10'),
    Func('Powi'),
    Func('Powf'),
    Func('Sqrt'),
    Func('Rsqrt'),
    Func('Ceil'),
    Func('Floor'),
    Func('Fract'),
    Func('Trunc'),
    Func('Round'),
    Func('Fma'),
    Func('Copysign'),
    Func('Cross'),
    Func('Dot'),
    Func('OuterProduct'),
    Func('Length'),
    Func('LengthSquared'),
    Func('Normalize'),
    Func('Faceforward'),
    Func('Distance'),
    Func('Reflect'),
    Func('Determinant'),
    Func('Transpose'),
    Func('Inverse'),

    # Warp*/SynchronizeBlock functions: all flagged side_effects=True.
    Func('WarpIsFirstActiveLane', side_effects=True),
    Func('WarpFirstActiveLane', side_effects=True),
    Func('WarpActiveAllEqual', side_effects=True),
    Func('WarpActiveBitAnd', side_effects=True),
    Func('WarpActiveBitOr', side_effects=True),
    Func('WarpActiveBitXor', side_effects=True),
    Func('WarpActiveCountBits', side_effects=True),
    Func('WarpActiveMax', side_effects=True),
    Func('WarpActiveMin', side_effects=True),
    Func('WarpActiveProduct', side_effects=True),
    Func('WarpActiveSum', side_effects=True),
    Func('WarpActiveAll', side_effects=True),
    Func('WarpActiveAny', side_effects=True),
    Func('WarpActiveBitMask', side_effects=True),
    Func('WarpPrefixCountBits', side_effects=True),
    Func('WarpPrefixSum', side_effects=True),
    Func('WarpPrefixProduct', side_effects=True),
    Func('WarpReadLaneAt', side_effects=True),
    Func('WarpReadFirstLane', side_effects=True),
    Func('SynchronizeBlock', side_effects=True),

    # Atomic* functions: all flagged side_effects=True.
    Func('AtomicExchange', [], side_effects=True),
    Func('AtomicCompareExchange', [], side_effects=True),
    Func('AtomicFetchAdd', [], side_effects=True),
    Func('AtomicFetchSub', [], side_effects=True),
    Func('AtomicFetchAnd', [], side_effects=True),
    Func('AtomicFetchOr', [], side_effects=True),
    Func('AtomicFetchXor', [], side_effects=True),
    Func('AtomicFetchMin', [], side_effects=True),
    Func('AtomicFetchMax', [], side_effects=True),

    # Resource access: *Write variants are flagged side_effects=True.
    Func('BufferWrite', [], side_effects=True),
    Func('BufferRead', []),
    Func('BufferSize', []),

    Func('ByteBufferWrite', [], side_effects=True),
    Func('ByteBufferRead', []),
    Func('ByteBufferSize', []),

    Func('Texture2dRead'),
    Func('Texture2dWrite', [], side_effects=True),
    Func('Texture2dSize'),

    Func('Texture3dRead'),
    Func('Texture3dWrite', [], side_effects=True),
    Func('Texture3dSize'),

    Func('BindlessTexture2dSample', []),
    Func('BindlessTexture2dSampleLevel', []),
    Func('BindlessTexture2dSampleGrad', []),
    Func('BindlessTexture2dSampleGradLevel', []),
    Func('BindlessTexture2dRead', []),
    Func('BindlessTexture2dReadLevel', []),
    Func('BindlessTexture2dSize', []),
    Func('BindlessTexture2dSizeLevel', []),

    Func('BindlessTexture3dSample', []),
    Func('BindlessTexture3dSampleLevel', []),
    Func('BindlessTexture3dSampleGrad', []),
    Func('BindlessTexture3dSampleGradLevel', []),
    Func('BindlessTexture3dRead', []),
    Func('BindlessTexture3dReadLevel', []),
    Func('BindlessTexture3dSize', []),
    Func('BindlessTexture3dSizeLevel', []),

    Func('BindlessBufferWrite', [], side_effects=True),
    Func('BindlessBufferRead', []),
    Func('BindlessBufferSize', []),
    Func('BindlessBufferType'),

    Func('BindlessByteBufferWrite', [], side_effects=True),
    Func('BindlessByteBufferRead', []),
    Func('BindlessByteBufferSize', []),

    Func('Vec', []),
    Func('Vec2', []),
    Func('Vec3', []),
    Func('Vec4', []),

    Func('Permute'),

    Func('GetElementPtr', []),
    Func('ExtractElement', []),
    Func('InsertElement', []),

    Func('Array'),
    Func('Struct'),

    Func('MatFull', []),
    Func('Mat2', []),
    Func('Mat3', []),
    Func('Mat4', []),

    # BindlessAtomic* functions each carry a `const Type*` field named `ty`.
    Func('BindlessAtomicExchange', [
        ('const Type*', 'ty')], side_effects=True),
    Func('BindlessAtomicCompareExchange', [
        ('const Type*', 'ty')], side_effects=True),
    Func('BindlessAtomicFetchAdd', [
         ('const Type*', 'ty')], side_effects=True),
    Func('BindlessAtomicFetchSub', [
         ('const Type*', 'ty')], side_effects=True),
    Func('BindlessAtomicFetchAnd', [
         ('const Type*', 'ty')], side_effects=True),
    Func('BindlessAtomicFetchOr', [
        ('const Type*', 'ty')], side_effects=True),
    Func('BindlessAtomicFetchXor', [
         ('const Type*', 'ty')], side_effects=True),
    Func('BindlessAtomicFetchMin', [
         ('const Type*', 'ty')], side_effects=True),
    Func('BindlessAtomicFetchMax', [
         ('const Type*', 'ty')], side_effects=True),

    # Callable / CpuExt hold their target through a luisa::shared_ptr field.
    Func('Callable', [
        ('luisa::shared_ptr<CallableModule>', 'module'),
    ]),
    Func('CpuExt', [
        ('luisa::shared_ptr<CpuExternFn>', 'f'),
    ]),
    Func('ShaderExecutionReorder')
]

# Plain helper structs referenced by the generated variants (PhiIncoming,
# SwitchCase, CpuExternFn) and by Func::metadata(); written to the forward
# header before any ADT is generated.
print(
    '''
struct PhiIncoming {
    BasicBlockRef block = nullptr;
    NodeRef value = nullptr;
};
struct SwitchCase {
    int32_t value = 0;
    BasicBlockRef block = nullptr;    
};
struct CpuExternFnData {
    void *data = nullptr;
    void (*func)(void *data, void *args) = nullptr;
    void (*dtor)(void *data) = nullptr;
    TypeRef arg_ty = nullptr;
};
struct CpuExternFn;
struct FuncMetadata {
    bool has_side_effects = false;    
};
const FuncMetadata* func_metadata();
''',
    file=fwd_file,
)
# `Func` is generated before `Instruction` because CallInst embeds a `Func`
# member by value, so `struct Func` must be complete first.
gen_adt(
    adt="Func",
    cpp_src='''    [[nodiscard]] const FuncMetadata& metadata() const noexcept;
    [[nodiscard]] bool has_side_effects() const noexcept { return metadata().has_side_effects; }
    Func():Func(Tag::UNDEF){}
''',
    variants=funcs,
)
gen_adt(adt="Instruction", cpp_src="", variants=instructions)

def gen_func_metadata():
    """Emit the per-variant ``FuncMetadata`` table and its accessors.

    Writes, in order:
      * to ir_v2_api.cpp (``c_api_impl``): a static ``FuncMetadata`` array
        holding one ``{ has_side_effects }`` entry per element of ``funcs``
        (same order), a ``static_assert`` on the array size, and the
        ``func_metadata()`` accessor that returns the array;
      * to ir_v2_defs.cpp (``cpp_api_impl``): ``Func::metadata()``, which
        indexes that array with ``static_cast<int>(tag())``.

    Finally registers ``func_metadata`` in ``func_table`` so it becomes a
    member of the generated binding table.

    NOTE(review): ``Func::metadata()`` assumes the Func tag values follow the
    order of ``funcs``; presumably guaranteed by gen_adt, confirm there.
    """
    table_lines = ['static FuncMetadata _func_metadata[] = {']
    for fn in funcs:
        flag = 'true' if fn.side_effects else 'false'
        table_lines.append(f'    {{ {flag} }},')
    table_lines.append('};')
    # NOTE(review): the array is unsized, so its size always equals the
    # initializer count and this assertion cannot fail; it documents intent.
    table_lines.append(
        f'static_assert(sizeof(_func_metadata) == sizeof(FuncMetadata) * {len(funcs)});')
    table_lines.append('const FuncMetadata* func_metadata() { return _func_metadata; }')
    for line in table_lines:
        print(line, file=c_api_impl)

    print('const FuncMetadata& Func::metadata() const noexcept {\n'
          '    return func_metadata()[static_cast<int>(tag())];\n'
          '}', file=cpp_api_impl)

    # Binding-table entry: (implementation name, function-pointer member decl).
    func_table.append(('func_metadata', 'const FuncMetadata* (*func_metadata)()'))


gen_func_metadata()

# Variants of the Binding ADT, as (variant name, [(C++ type, field name), ...]).
# Every variant carries a uint64_t handle; buffers add an offset/size window
# and textures add a level.
bindings = [
    Item(variant, 'Binding', fields)
    for variant, fields in (
        ('BufferBinding', [('uint64_t', 'handle'), ('uint64_t', 'offset'), ('uint64_t', 'size')]),
        ('TextureBinding', [('uint64_t', 'handle'), ('uint64_t', 'level')]),
        ('BindlessArrayBinding', [('uint64_t', 'handle')]),
        ('AccelBinding', [('uint64_t', 'handle')]),
    )
]
gen_adt('Binding', '', bindings)

def gen_extra_bindings() -> None:
    """Register the hand-written (non-ADT) entry points of the binding table.

    Each ``add_func`` call declares one C++ function in ir_v2_fwd.h and
    appends it to ``func_table``. The ``IrV2BindingTable`` struct and its
    initializer are generated from ``func_table`` later in this script, so
    the order of the calls below fixes the field order of that struct.
    """
    def add_func(ret: str, name: str, args: str) -> None:
        """Declare ``ir_v2_binding_<name>`` and add it to the binding table.

        Args:
            ret: C++ return type, spelled as it should appear in the output.
            name: Member name in ``IrV2BindingTable``; the declared function
                gets the ``ir_v2_binding_`` prefix.
            args: C++ parameter list without parentheses; may be empty.
        """
        fname = f'ir_v2_binding_{name}'
        # The table member uses the unprefixed name; the table initializer
        # refers to the prefixed function name (fname).
        fsig = f'{ret} (*{name})({args})'
        func_table.append((fname, fsig))
        # Only a declaration is emitted; the definition is not generated here.
        print(f'{ret} {fname}({args});', file=fwd_file)
    # --- Type queries ---
    add_func('const Type*', 'type_extract', 'const Type* ty, uint32_t index')
    add_func('size_t', 'type_size', 'const Type* ty')
    add_func('size_t', 'type_alignment', 'const Type* ty')
    add_func('RustyTypeTag', 'type_tag', 'const Type* ty')
    add_func('bool', 'type_is_scalar', 'const Type* ty')
    add_func('bool', 'type_is_bool', 'const Type* ty')
    add_func('bool', 'type_is_int16', 'const Type* ty')
    add_func('bool', 'type_is_int32', 'const Type* ty')
    add_func('bool', 'type_is_int64', 'const Type* ty')
    add_func('bool', 'type_is_uint16', 'const Type* ty')
    add_func('bool', 'type_is_uint32', 'const Type* ty')
    add_func('bool', 'type_is_uint64', 'const Type* ty')
    add_func('bool', 'type_is_float16', 'const Type* ty')
    add_func('bool', 'type_is_float32', 'const Type* ty')

    # --- Aggregate / composite type queries ---
    add_func('bool', 'type_is_array', 'const Type* ty')
    add_func('bool', 'type_is_vector', 'const Type* ty')
    add_func('bool', 'type_is_struct', 'const Type* ty')
    add_func('bool', 'type_is_custom', 'const Type* ty')
    add_func('bool', 'type_is_matrix', 'const Type* ty')
    
    add_func('const Type*', 'type_element', 'const Type* ty')
    add_func('Slice<const char>', 'type_description', 'const Type* ty')
    add_func('size_t', 'type_dimension', 'const Type* ty')
    add_func('Slice<const Type* const>', 'type_members', 'const Type* ty')

    # --- Type constructors ---
    add_func('const Type*', 'make_struct', 'size_t alignment, const Type** tys, uint32_t count')
    add_func('const Type*', 'make_array', 'const Type* ty, uint32_t count')
    add_func('const Type*', 'make_vector', 'const Type* ty, uint32_t count')
    add_func('const Type*', 'make_matrix', 'uint32_t dim')
    add_func('const Type*', 'make_custom', 'Slice<const char> name')
    add_func('const Type*', 'from_desc', 'Slice<const char> desc')

    # --- Built-in scalar types (no arguments) ---
    add_func('const Type*', 'type_bool', '')
    add_func('const Type*', 'type_int16', '')
    add_func('const Type*', 'type_int32', '')
    add_func('const Type*', 'type_int64', '')
    add_func('const Type*', 'type_uint16', '')
    add_func('const Type*', 'type_uint32', '')
    add_func('const Type*', 'type_uint64', '')
    add_func('const Type*', 'type_float16', '')
    add_func('const Type*', 'type_float32', '')

    # --- Read-only node accessors ---
    add_func('const Node*', 'node_prev', 'const Node* node')
    add_func('const Node*', 'node_next', 'const Node* node')
    add_func('const CInstruction*', 'node_inst', 'const Node* node')
    add_func('const Type*', 'node_type', 'const Node* node')
    add_func('int32_t', 'node_get_index', 'const Node* node')

    # --- Basic block accessors ---
    add_func('const Node*', 'basic_block_first', 'const BasicBlock* block')
    add_func('const Node*', 'basic_block_last', 'const BasicBlock* block')

    # --- Node mutation (take non-const Node*) ---
    add_func('void','node_unlink', 'Node* node')
    add_func('void','node_set_next', 'Node* node, Node* next')
    add_func('void','node_set_prev', 'Node* node, Node* prev')
    add_func('void','node_replace', 'Node* node, Node* new_node')

    # --- Pool lifetime ---
    add_func('Pool*', 'pool_new', '')
    add_func('void', 'pool_drop', 'Pool* pool')
    add_func('Pool*', 'pool_clone', 'Pool* pool')

    # --- IR builder ---
    add_func('IrBuilder*', 'ir_builder_new', 'Pool* pool')
    add_func('IrBuilder*', 'ir_builder_new_without_bb', 'Pool* pool')
    add_func('void', 'ir_builder_drop', 'IrBuilder* builder')
    add_func('void', 'ir_builder_set_insert_point', 'IrBuilder* builder, Node* node')
    add_func('Node *', 'ir_builder_insert_point', 'IrBuilder* builder')
    add_func('Node *','ir_build_call', 'IrBuilder* builder, CFunc &&func, Slice<const Node* const> args, const Type* ty')
    add_func('Node *','ir_build_call_tag', 'IrBuilder* builder, RustyFuncTag tag, Slice<const Node* const> args, const Type* ty')
    add_func('Node *', 'ir_build_if', 'IrBuilder* builder, const Node* cond, const BasicBlock* true_branch, const BasicBlock* false_branch')
    add_func('Node *', 'ir_build_generic_loop', 'IrBuilder* builder, const BasicBlock* prepare, const Node* cond, const BasicBlock* body, const BasicBlock* update')
    add_func('Node *', 'ir_build_switch', 'IrBuilder* builder, const Node* value, Slice<const SwitchCase> cases, const BasicBlock* default_')
    add_func('Node *', 'ir_build_local', 'IrBuilder* builder, const Node* init')
    add_func('Node *', 'ir_build_break', 'IrBuilder* builder')
    add_func('Node *', 'ir_build_continue', 'IrBuilder* builder')
    add_func('Node *', 'ir_build_return', 'IrBuilder* builder, const Node* value')
    # Takes the builder by rvalue reference (declared as IrBuilder&&).
    add_func('const BasicBlock*', 'ir_builder_finish', 'IrBuilder&& builder')

    # --- CPU extern functions (see CpuExternFnData in ir_v2_fwd.h) ---
    add_func('const CpuExternFnData*', 'cpu_ext_fn_data', 'const CpuExternFn* f')
    add_func('const CpuExternFn*', 'cpu_ext_fn_new', 'CpuExternFnData')
    add_func('const CpuExternFn*', 'cpu_ext_fn_clone', 'const CpuExternFn* f')
    add_func('void', 'cpu_ext_fn_drop', 'const CpuExternFn* f')


gen_extra_bindings()

# Binding table. Each func_table entry is (implementation name,
# function-pointer member declaration). The header gets one member per entry,
# and the implementation returns an aggregate listing the implementations.
# Aggregate initialization is positional, so both passes must walk func_table
# in the same order.
print('struct IrV2BindingTable {', file=c_def)
for _impl_name, member_decl in func_table:
    print(f'    {member_decl};', file=c_def)
print('};', file=c_def)
print('extern "C" LC_IR_API IrV2BindingTable lc_ir_v2_binding_table();', file=c_def)

print('extern "C" LC_IR_API IrV2BindingTable lc_ir_v2_binding_table() {', file=c_api_impl)
print('    return {', file=c_api_impl)
for impl_name, _member_decl in func_table:
    print(f'        {impl_name},', file=c_api_impl)
print('    };', file=c_api_impl)
print('}', file=c_api_impl)


import subprocess
import sys

# Generated outputs, in the order they are opened at the top of the file.
_generated_files = (cpp_def, fwd_file, c_def, c_api_impl, cpp_api_impl)

# Each output gets one final closing brace, matching a scope opened earlier in
# the generator (not visible in this part of the file; presumably a namespace).
for _out in _generated_files:
    print('}', file=_out)
for _out in _generated_files:
    _out.close()

# Run clang-format in place on every generated file.
#  - The paths are taken from the file handles, so they cannot drift from the
#    open() calls.
#  - The tool is run without a shell.
#  - Formatting is best-effort: a failure or missing tool only warns.
for _out in _generated_files:
    try:
        _status = subprocess.run(['clang-format', '-i', _out.name]).returncode
    except OSError as err:
        print(f'warning: could not run clang-format ({err}); generated files left unformatted',
              file=sys.stderr)
        break
    if _status != 0:
        print(f'warning: clang-format exited with status {_status} on {_out.name}',
              file=sys.stderr)

import subprocess
import sys

# Generate the Rust FFI bindings for the binding-table header (ir_v2_api.h).
# The arguments are passed as a list, without a shell. Several of them are
# regexes (`.*`, `_.*`, `.*Tag`, `.*Ref`). Through os.system a POSIX shell
# would glob-expand them against the current directory before bindgen saw
# them; `.*` typically matches `.`, `..` and any dotfiles, which corrupts the
# command line.
_BINDGEN_ARGS = [
    'bindgen', 'ir_v2_api.h',
    '-o', '../../../src/rust/luisa_compute_ir_v2/src/binding.rs',
    '--rustified-enum', '.*Tag',
    '--disable-name-namespacing',
    '--blocklist-type', '_.*',
    '--blocklist-function', '_.*',
    '--blocklist-item', '_.*',
    # All free functions are blocked; Rust reaches them via IrV2BindingTable.
    '--blocklist-function', '.*',
    '--blocklist-type', 'TypeTag',
    '--blocklist-type', 'InstructionTag',
    '--blocklist-type', 'FuncTag',
    '--blocklist-type', 'BindingTag',
    '--new-type-alias', '.*Ref',
    '--new-type-alias', '.*RefMut',
    '--with-derive-partialeq', '--with-derive-eq', '--with-derive-hash',
    # Everything after `--` is forwarded to clang.
    '--',
    '-I../../', '-x', 'c++', '-std=c++17',
    '-DLC_IR_EXPORT_DLL=1', '-DBINDGEN',
    '-Wno-pragma-once-outside-header', '-Wno-return-type-c-linkage',
]
try:
    _bindgen_status = subprocess.run(_BINDGEN_ARGS).returncode
except OSError as err:
    print(f'warning: could not run bindgen ({err}); binding.rs not regenerated',
          file=sys.stderr)
else:
    if _bindgen_status != 0:
        print(f'warning: bindgen exited with status {_bindgen_status}', file=sys.stderr)



    
